# test_tts_service_v2.py
# Description: Tests for the V2 TTS Service architecture
#
# Imports
import asyncio
import base64
import os
import tempfile
from unittest.mock import AsyncMock, MagicMock, patch, Mock
import pytest
import numpy as np
from typing import Dict, Any

# Local Imports
from tldw_Server_API.app.api.v1.schemas.audio_schemas import OpenAISpeechRequest
from tldw_Server_API.app.core.TTS.tts_service_v2 import TTSServiceV2
from tldw_Server_API.app.core.TTS.adapter_registry import TTSAdapterRegistry, TTSProvider, TTSAdapterFactory
from tldw_Server_API.app.core.TTS.adapters.base import (
    TTSAdapter,
    TTSCapabilities,
    TTSRequest,
    TTSResponse,
    AudioFormat,
    ProviderStatus
)
from tldw_Server_API.app.core.TTS.circuit_breaker import CircuitBreaker, CircuitState
from tldw_Server_API.app.core.TTS.audio_utils import AudioProcessor, process_voice_reference
from tldw_Server_API.app.core.TTS.tts_exceptions import TTSGenerationError, TTSProviderError

#######################################################################################################################
#
# Test Classes


class MockAdapter(TTSAdapter):
    """Mock adapter for testing"""

    def __init__(self, config: Dict[str, Any] = None):
        # The parent expects provider_config, but registry passes config
        provider_config = config or {}
        super().__init__(provider_config)
        self.initialized = False
        self.generate_called = False
        self.provider_id = provider_config.get("name", "mock")

    async def initialize(self) -> bool:
        self.initialized = True
        self._initialized = True
        self._status = ProviderStatus.AVAILABLE
        self._capabilities = await self.get_capabilities()
        return True

    async def generate(self, request: TTSRequest) -> TTSResponse:
        self.generate_called = True
        target_format = request.format if isinstance(request.format, AudioFormat) else AudioFormat.MP3

        if request.stream:
            async def _stream():
                for chunk in (b"mock ", b"stream ", b"data"):
                    yield chunk

            return TTSResponse(
                audio_stream=_stream(),
                format=target_format,
                provider=self.provider_id
            )

        return TTSResponse(
            audio_data=b"mock audio data",
            format=target_format,
            provider=self.provider_id
        )

    async def get_capabilities(self) -> TTSCapabilities:
        from tldw_Server_API.app.core.TTS.adapters.base import VoiceInfo, AudioFormat
        return TTSCapabilities(
            provider_name="mock",
            supports_streaming=True,
            supports_voice_cloning=False,
            supported_languages={"en"},
            supported_formats={AudioFormat.MP3, AudioFormat.WAV},
            max_text_length=5000,
            supported_voices=[
                VoiceInfo(id="voice1", name="Voice 1"),
                VoiceInfo(id="voice2", name="Voice 2")
            ]
        )

    async def list_voices(self):
        return ["voice1", "voice2"]

    async def cleanup(self):
        self.initialized = False


class TestTTSServiceV2:
    """Tests for TTSServiceV2"""

    import pytest_asyncio
    @pytest_asyncio.fixture
    async def service(self):
        """Create a test service instance"""
        config = {
            "providers": {
                "mock": {
                    "enabled": True,
                    "priority": 1
                }
            },
            "fallback": {
                "enabled": True,
                "max_attempts": 2
            }
        }

        # Create a mock factory
        factory = MagicMock(spec=TTSAdapterFactory)
        factory.create_adapter = AsyncMock(return_value=MockAdapter(config.get("mock", {})))
        factory.get_adapter = AsyncMock(return_value=MockAdapter(config.get("mock", {})))
        factory.list_adapters = AsyncMock(return_value=[TTSProvider.MOCK])
        factory.get_adapter_by_model = AsyncMock(return_value=MockAdapter({"name": "mock"}))
        factory.get_best_adapter = AsyncMock(return_value=MockAdapter({"name": "mock"}))
        factory.registry = MagicMock(spec=TTSAdapterRegistry)
        factory.registry.get_all_capabilities = AsyncMock(return_value={})
        factory.registry.get_adapter = AsyncMock(return_value=MockAdapter({"name": "mock"}))

        service = TTSServiceV2(factory)

        class MetricsStub:
            def register_metric(self, *args, **kwargs):
                return None

            def set_gauge(self, *args, **kwargs):
                return None

            def increment(self, *args, **kwargs):
                return None

            def observe(self, *args, **kwargs):
                return None

            def gauge_add(self, *args, **kwargs):
                return None

        service.metrics = MetricsStub()

        return service

    @pytest.mark.asyncio
    async def test_service_initialization(self, service):
        """Test service initializes properly"""
        assert service is not None
        assert hasattr(service, 'factory')
        # Check that service is properly initialized
        assert service.factory is not None

    @pytest.mark.asyncio
    async def test_generate_speech(self, service):
        """Test basic speech generation"""
        # Use OpenAISpeechRequest which the service expects
        request = OpenAISpeechRequest(
            input="Test text",
            model="mock",
            voice="voice1",
            response_format="mp3"
        )

        request.stream = False

        mock_adapter = MockAdapter({"name": "mock"})
        await mock_adapter.initialize()
        service.factory.get_adapter_by_model = AsyncMock(return_value=mock_adapter)

        chunks = [chunk async for chunk in service.generate_speech(request)]
        payload = b"".join(chunks)

        assert payload == b"mock audio data"

    @pytest.mark.asyncio
    async def test_generate_stream(self, service):
        """Test streaming generation"""
        # Use OpenAISpeechRequest
        request = OpenAISpeechRequest(
            input="Test text",
            model="mock",
            voice="voice1",
            response_format="mp3"
        )

        mock_adapter = MockAdapter({"name": "mock"})
        await mock_adapter.initialize()
        service.factory.get_adapter_by_model = AsyncMock(return_value=mock_adapter)

        chunks = [chunk async for chunk in service.generate_speech(request)]

        assert len(chunks) == 3
        assert b"".join(chunks) == b"mock stream data"

    @pytest.mark.asyncio
    async def test_get_capabilities(self, service):
        """Test getting provider capabilities"""
        mock_adapter = MockAdapter({"name": "mock"})
        await mock_adapter.initialize()
        mock_caps = await mock_adapter.get_capabilities()
        service.factory.registry.get_all_capabilities = AsyncMock(return_value={
            TTSProvider.MOCK: mock_caps
        })

        caps = await service.get_capabilities()

        assert caps is not None
        assert "mock" in caps
        provider_caps = caps["mock"]
        assert provider_caps["supports_streaming"] is True
        assert "en" in provider_caps["languages"]
        assert "mp3" in provider_caps["formats"]

    @pytest.mark.asyncio
    async def test_list_providers(self, service):
        """Test listing available providers"""
        # Mock the factory's get_status
        service.factory.get_status = MagicMock(return_value={
            "providers": {
                "mock": "available"
            }
        })

        # The service may not have list_providers, check what it has
        status = service.factory.get_status()

        assert "providers" in status
        assert "mock" in status["providers"]
        assert status["providers"]["mock"] == "available"

    @pytest.mark.asyncio
    async def test_fallback_mechanism(self):
        """Test fallback to another provider on failure"""
        config = {
            "providers": {
                "failing": {
                    "enabled": True,
                    "priority": 1
                },
                "working": {
                    "enabled": True,
                    "priority": 2
                }
            },
            "fallback": {
                "enabled": True,
                "max_attempts": 3
            }
        }

        # Create failing adapter
        failing_adapter = MockAdapter({"name": "failing"})

        # Create working adapter
        working_adapter = MockAdapter({"name": "working"})

        # Create mock factory that returns appropriate adapters
        factory = MagicMock(spec=TTSAdapterFactory)
        factory.create_adapter = AsyncMock(side_effect=lambda name, *args:
                                          failing_adapter if name == "failing" else working_adapter)
        factory.get_adapter = AsyncMock(side_effect=lambda name:
                                       failing_adapter if name == "failing" else working_adapter)
        factory.list_adapters = AsyncMock(return_value=["failing", "working"])

        request = OpenAISpeechRequest(
            input="Test text",
            model="failing",
            voice="voice1",
            response_format="mp3"
        )
        request.stream = False

        failing_error = TTSProviderError("Provider failed", provider="failing")
        failing_adapter.generate = AsyncMock(side_effect=failing_error)
        await failing_adapter.initialize()
        await working_adapter.initialize()

        factory.get_adapter_by_model = AsyncMock(return_value=failing_adapter)
        factory.get_best_adapter = AsyncMock(return_value=working_adapter)
        factory.registry = MagicMock(spec=TTSAdapterRegistry)
        factory.registry.get_adapter = AsyncMock(return_value=working_adapter)

        service = TTSServiceV2(factory)

        class MetricsStub:
            def register_metric(self, *args, **kwargs):
                return None

            def set_gauge(self, *args, **kwargs):
                return None

            def increment(self, *args, **kwargs):
                return None

            def observe(self, *args, **kwargs):
                return None

            def gauge_add(self, *args, **kwargs):
                return None

        service.metrics = MetricsStub()

        chunks = [chunk async for chunk in service.generate_speech(request)]
        payload = b"".join(chunks)

        assert payload == b"mock audio data"

    @pytest.mark.asyncio
    async def test_circuit_breaker(self, service):
        """Test circuit breaker functionality"""
        # Get circuit breaker for mock provider
        cb = service.circuit_manager.get_circuit("mock") if getattr(service, 'circuit_manager', None) else None

        if cb:
            assert cb.state == CircuitState.CLOSED

            # Simulate failures (use private method)
            for _ in range(5):
                cb._record_failure()

            # Circuit should be open
            assert cb.state == CircuitState.OPEN

            # Should not allow calls
            assert not cb.allow_request()


class TestAudioUtils:
    """Tests for audio processing utilities"""

    def test_audio_processor_initialization(self):
        """Test AudioProcessor initialization"""
        processor = AudioProcessor()
        assert processor is not None
        assert "higgs" in processor.PROVIDER_REQUIREMENTS
        assert "chatterbox" in processor.PROVIDER_REQUIREMENTS
        assert "vibevoice" in processor.PROVIDER_REQUIREMENTS

    def test_validate_audio(self):
        """Test audio validation"""
        processor = AudioProcessor()

        # Create a simple WAV header (minimal valid WAV)
        wav_header = b'RIFF\x24\x00\x00\x00WAVEfmt \x10\x00\x00\x00\x01\x00\x01\x00"V\x00\x00D\xac\x00\x00\x02\x00\x10\x00data\x00\x00\x00\x00'
        # Add some audio data (5 seconds worth at 22050Hz, 16-bit mono)
        audio_data = b'\x00\x00' * (22050 * 5)
        valid_wav = wav_header + audio_data

        # Test validation for Higgs provider
        is_valid, msg, info = processor.validate_audio(valid_wav, "higgs", check_duration=False)
        # Note: Without proper audio libs, this might fail, but at least test the method exists
        assert isinstance(is_valid, bool)
        if msg:
            assert isinstance(msg, str)

    def test_process_voice_reference(self):
        """Test voice reference processing"""
        # Create a base64 encoded simple audio
        audio_data = b'RIFF\x24\x00\x00\x00WAVEfmt \x10\x00\x00\x00\x01\x00\x01\x00"V\x00\x00D\xac\x00\x00\x02\x00\x10\x00data\x00\x00\x00\x00' + b'\x00\x00' * 1000
        base64_audio = base64.b64encode(audio_data).decode('utf-8')

        # Process
        processed, error = process_voice_reference(
            base64_audio,  # Pass base64 string, not raw bytes
            provider="higgs",
            validate=False,  # Skip validation for test
            convert=False
        )

        # Check result
        assert processed is not None or error is not None
        if error:
            assert isinstance(error, str)

    @pytest.mark.asyncio
    async def test_base64_encoding(self):
        """Test base64 encoding/decoding for voice references"""
        original_data = b"test audio data"

        # Encode
        encoded = base64.b64encode(original_data).decode('utf-8')

        # Decode
        decoded = base64.b64decode(encoded)

        assert decoded == original_data


class TestCircuitBreaker:
    """Tests for Circuit Breaker"""

    def test_circuit_breaker_initialization(self):
        """Test circuit breaker initialization"""
        cb = CircuitBreaker(
            provider_name="test",
            failure_threshold=3,
            recovery_timeout=10
        )

        assert cb.state == CircuitState.CLOSED
        # Check status instead of direct attribute
        status = cb.get_status()
        assert status["stats"]["failure_count"] == 0

    @pytest.mark.asyncio
    async def test_circuit_opens_after_threshold(self):
        """Test circuit opens after failure threshold"""
        cb = CircuitBreaker(provider_name="test", failure_threshold=3)

        # Define a failing function
        async def failing_func():
            raise Exception("Test failure")

        # Record failures through the call method
        for _ in range(3):
            try:
                await cb.call(failing_func)
            except Exception:
                pass  # Expected to fail

        assert cb.state == CircuitState.OPEN
        assert not cb.is_available

    @pytest.mark.asyncio
    async def test_circuit_recovery(self):
        """Test circuit recovery to half-open state"""
        cb = CircuitBreaker(
            provider_name="test_provider",
            failure_threshold=2,
            recovery_timeout=0.1,  # 100ms for testing
            success_threshold=2
        )

        # Define functions for testing
        async def failing_func():
            raise Exception("Test failure")

        async def success_func():
            return "success"

        # Open the circuit
        for _ in range(2):
            try:
                await cb.call(failing_func)
            except Exception:
                pass
        assert cb.state == CircuitState.OPEN

        # Wait for recovery timeout
        await asyncio.sleep(0.2)

        # Should transition to half-open on next check
        assert cb.is_available  # This triggers transition
        assert cb.state == CircuitState.HALF_OPEN

        # Success should close circuit after success_threshold
        for _ in range(2):
            await cb.call(success_func)
        assert cb.state == CircuitState.CLOSED

    @pytest.mark.asyncio
    async def test_circuit_reopens_on_failure_in_half_open(self):
        """Test circuit reopens if failure occurs in half-open state"""
        cb = CircuitBreaker(
            provider_name="test",
            failure_threshold=2,
            recovery_timeout=0.1
        )

        # Define a failing function
        async def failing_func():
            raise Exception("Test failure")

        # Open circuit
        for _ in range(2):
            try:
                await cb.call(failing_func)
            except Exception:
                pass

        # Wait and transition to half-open
        await asyncio.sleep(0.2)
        assert cb.is_available  # Triggers transition
        assert cb.state == CircuitState.HALF_OPEN

        # Failure in half-open should reopen
        try:
            await cb.call(failing_func)
        except Exception:
            pass  # Expected
        assert cb.state == CircuitState.OPEN


class TestAdapterRegistry:
    """Tests for Adapter Registry"""

    def test_registry_initialization(self):
        """Test registry initialization"""
        config = {
            "providers": {
                "openai": {"enabled": True},
                "kokoro": {"enabled": False}
            }
        }

        registry = TTSAdapterRegistry(config)
        assert registry is not None
        assert registry.config == config

    @pytest.mark.asyncio
    async def test_register_adapter(self):
        """Test registering an adapter"""
        registry = TTSAdapterRegistry({})

        # Register the adapter CLASS, not an instance
        registry.register_adapter(TTSProvider.MOCK, MockAdapter)

        # Check it was registered in the adapter_classes dict
        assert TTSProvider.MOCK in registry._adapter_specs
        assert registry._adapter_specs[TTSProvider.MOCK] == MockAdapter

    @pytest.mark.asyncio
    async def test_get_adapter(self):
        """Test getting an adapter"""
        # Create registry with mock config
        config = {
            "mock_enabled": True
        }
        registry = TTSAdapterRegistry(config)

        # Register the MockAdapter class
        registry.register_adapter(TTSProvider.MOCK, MockAdapter)

        # Get the adapter (this will initialize it)
        retrieved = await registry.get_adapter(TTSProvider.MOCK)

        # Check we got an instance
        assert retrieved is not None
        assert isinstance(retrieved, MockAdapter)

    @pytest.mark.asyncio
    async def test_list_adapters(self):
        """Test listing adapters"""
        config = {
            "mock_enabled": True,
            "openai_enabled": False
        }
        registry = TTSAdapterRegistry(config)

        # Register MockAdapter
        registry.register_adapter(TTSProvider.MOCK, MockAdapter)

        # Initialize an adapter
        await registry.get_adapter(TTSProvider.MOCK)

        # Get status summary
        status = registry.get_status_summary()

        assert "providers" in status
        assert TTSProvider.MOCK.value in status["providers"]


class TestVoiceCloning:
    """Tests for voice cloning functionality"""

    @pytest.mark.asyncio
    async def test_voice_reference_in_request(self):
        """Test voice reference field in request"""
        voice_data = base64.b64encode(b"audio data").decode()

        request = TTSRequest(
            text="Test with voice cloning",
            voice="clone",
            voice_reference=voice_data
        )

        assert request.voice_reference == voice_data
        assert request.voice == "clone"

    def test_voice_reference_validation(self):
        """Test voice reference validation"""
        processor = AudioProcessor()

        # Test provider requirements
        reqs = processor.PROVIDER_REQUIREMENTS.get("higgs")
        assert reqs is not None
        assert reqs["min_duration"] == 3.0
        assert reqs["max_duration"] == 10.0
        assert reqs["preferred_sample_rate"] == 24000

        reqs = processor.PROVIDER_REQUIREMENTS.get("chatterbox")
        assert reqs is not None
        assert reqs["min_duration"] == 5.0
        assert reqs["max_duration"] == 20.0

        reqs = processor.PROVIDER_REQUIREMENTS.get("vibevoice")
        assert reqs is not None
        assert reqs["min_duration"] == 3.0
        assert reqs["max_duration"] == 30.0

    @pytest.mark.asyncio
    async def test_temp_file_cleanup(self):
        """Test temporary file cleanup for voice references"""
        # Create a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            temp_path = f.name
            f.write(b"test audio data")

        assert os.path.exists(temp_path)

        # Simulate cleanup
        try:
            # Process file
            pass
        finally:
            if os.path.exists(temp_path):
                os.remove(temp_path)

        assert not os.path.exists(temp_path)


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
